import copy
import random
import time

class CliffWalkingEnv:
    """Cliff Walking environment."""
    def __init__(self, ncol=12, nrow=4):
        self.ncol = ncol
        self.nrow = nrow
        # Transition model: P[state][action] = [(p, next_state, reward, done)],
        # i.e. the next state and reward for each state-action pair
        self.P = self.createP()
    
    def createP(self):
        # Initialize the transition model
        P = [[[] for j in range(4)] for i in range(self.nrow * self.ncol)]
        # Four actions: change[0] = up, change[1] = down, change[2] = left,
        # change[3] = right. The origin (0, 0) is at the top-left corner.
        change = [[0, -1], [0, 1], [-1, 0], [1, 0]]
        for i in range(self.nrow):
            for j in range(self.ncol):
                for a in range(4):  # (0: up, 1: down, 2: left, 3: right)
                    # States on the cliff or at the goal are terminal: no further
                    # interaction is possible, so every action keeps the agent in
                    # place with reward 0
                    if i == self.nrow - 1 and j > 0:
                        P[i * self.ncol + j][a] = [(1, i * self.ncol + j, 0, True)]
                        continue
                    # All other states
                    next_x = min(self.ncol - 1, max(0, j + change[a][0]))
                    next_y = min(self.nrow - 1, max(0, i + change[a][1]))
                    next_state = next_y * self.ncol + next_x
                    reward = -1
                    done = False
                    # The next state is on the cliff or at the goal
                    if next_y == self.nrow - 1 and next_x > 0:
                        done = True
                        if next_x != self.ncol - 1:  # on the cliff
                            reward = -100
                    P[i * self.ncol + j][a] = [(1, next_state, reward, done)]
        return P
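
# A minimal sanity check of the transition model (an illustrative helper, not
# part of the original script): in this deterministic grid, every P[s][a]
# holds exactly one (p, next_state, reward, done) tuple with p == 1.
def _check_transition_model(env):
    n_states = env.nrow * env.ncol
    for s in range(n_states):
        for a in range(4):
            assert len(env.P[s][a]) == 1  # deterministic: one outcome per action
            p, next_state, reward, done = env.P[s][a][0]
            assert p == 1 and 0 <= next_state < n_states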

class PolicyIteration:
    """Policy iteration algorithm."""
    def __init__(self, env, theta, gamma):
        self.env = env
        self.v = [0] * (self.env.ncol * self.env.nrow)  # initialize all state values to 0
        self.pi = [[0.25, 0.25, 0.25, 0.25]
                   for i in range(self.env.ncol * self.env.nrow)]  # start from the uniform random policy
        self.theta = theta  # convergence threshold for policy evaluation
        self.gamma = gamma  # discount factor
    
    def policy_evaluation(self):  # policy evaluation
        """Compute the state-value function V(s) of the current policy pi."""
        cnt = 1  # iteration counter
        while True:
            max_diff = 0
            new_v = [0] * (self.env.ncol * self.env.nrow)  # fresh state-value table
            for s in range(self.env.ncol * self.env.nrow):  # loop over all states
                qsa_list = []  # the policy-weighted Q(s, a) values for state s
                for a in range(4):  # loop over all actions
                    qsa = 0
                    for res in self.env.P[s][a]:  # compute Q(s, a)
                        p, next_state, r, done = res
                        qsa += p * (r + self.gamma * self.v[next_state] * (1 - done))
                    # In this environment the reward depends on the next state,
                    # so it is weighted by the state transition probability
                    qsa_list.append(self.pi[s][a] * qsa)  # weight Q(s, a) by pi(a|s)
                new_v[s] = sum(qsa_list)  # V(s) is the policy-weighted sum of Q(s, a)
                max_diff = max(max_diff, abs(new_v[s] - self.v[s]))
            self.v = new_v
            if max_diff < self.theta: break  # converged: stop the evaluation loop
            cnt += 1
        print(f"Policy evaluation finished after {cnt} iterations")
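
    # The evaluation sweep above implements the Bellman expectation backup
    # (restated here for reference, not extra functionality):
    #   V(s) <- sum_a pi(a|s) * sum_{s'} p(s'|s,a) * (r + gamma * V(s'))
    # repeated until max_s |V_new(s) - V(s)| < theta.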
    
    def policy_improvement(self):  # policy improvement
        """Make the policy greedy with respect to the current value function V(s)."""
        for s in range(self.env.nrow * self.env.ncol):
            qsa_list = []
            for a in range(4):
                qsa = 0
                for res in self.env.P[s][a]:
                    p, next_state, r, done = res
                    qsa += p * (r + self.gamma * self.v[next_state] * (1 - done))  # (1 - done) zeroes out V(s') when the transition terminates
                qsa_list.append(qsa)
            maxq = max(qsa_list)  # the largest Q value
            cntq = qsa_list.count(maxq)  # number of actions attaining the maximum
            # Split the probability evenly among the maximizing actions
            self.pi[s] = [1 / cntq if q == maxq else 0 for q in qsa_list]
        print("Policy improvement finished")
        return self.pi
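
    # Greedy improvement is justified by the policy improvement theorem:
    # acting greedily with respect to V^pi yields a policy pi' satisfying
    # V^pi'(s) >= V^pi(s) for all s, so policy iteration converges to an
    # optimal policy in a finite MDP.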
    
    def policy_iteration(self):  # policy iteration
        while True:
            self.policy_evaluation()
            old_pi = copy.deepcopy(self.pi)  # deep-copy the policy so it can be compared below
            new_pi = self.policy_improvement()
            if old_pi == new_pi: break  # the policy is stable: stop

class ValueIteration:
    """Value iteration algorithm."""
    def __init__(self, env, theta, gamma):
        self.env = env
        self.v = [0] * (self.env.ncol * self.env.nrow)  # initialize all state values to 0
        self.theta = theta  # convergence threshold for the value updates
        self.gamma = gamma
        # The policy extracted after value iteration has converged
        self.pi = [None for i in range(self.env.ncol * self.env.nrow)]

    def value_iteration(self):
        cnt = 0
        while True:
            max_diff = 0
            new_v = [0] * (self.env.ncol * self.env.nrow)
            for s in range(self.env.ncol * self.env.nrow):
                qsa_list = []  # the Q(s, a) values for state s
                for a in range(4):
                    qsa = 0
                    for res in self.env.P[s][a]:
                        p, next_state, r, done = res
                        qsa += p * (r + self.gamma * self.v[next_state] * (1 - done))
                    qsa_list.append(qsa)
                new_v[s] = max(qsa_list)  # the key difference from policy evaluation: max over actions
                max_diff = max(max_diff, abs(new_v[s] - self.v[s]))
            self.v = new_v
            if max_diff < self.theta: break  # converged: stop the value updates
            cnt += 1
        print(f"Value iteration finished after {cnt} iterations")
        self.get_policy()
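
    # Each sweep above is the Bellman optimality backup (restated for
    # reference):
    #   V(s) <- max_a sum_{s'} p(s'|s,a) * (r + gamma * V(s'))
    # No explicit policy is maintained during the sweeps; a greedy policy is
    # extracted by get_policy once V has converged.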

    def get_policy(self):  # extract the greedy policy from the value function
        for s in range(self.env.ncol * self.env.nrow):
            qsa_list = []
            for a in range(4):
                qsa = 0
                for res in self.env.P[s][a]:
                    p, next_state, r, done = res
                    qsa += p * (r + self.gamma * self.v[next_state] * (1 - done))
                qsa_list.append(qsa)
            maxq = max(qsa_list)
            cntq = qsa_list.count(maxq)  # number of actions attaining the maximum
            # Split the probability evenly among the maximizing actions
            self.pi[s] = [1 / cntq if q == maxq else 0 for q in qsa_list]
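
# An illustrative cross-check (hypothetical helper, not in the original
# script): on this MDP policy iteration and value iteration should reach the
# same greedy policy, although near-ties may differ for a nonzero theta.
def _policies_agree(pi_a, pi_b):
    # Compare the sets of greedy (nonzero-probability) actions per state.
    return all([p > 0 for p in a] == [p > 0 for p in b]
               for a, b in zip(pi_a, pi_b))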

def print_agent(agent, action_meaning, disaster=[], end=[]):
    print("State values:")
    for i in range(agent.env.nrow):
        for j in range(agent.env.ncol):
            # Keep each value 6 characters wide for aligned output
            print('%6.6s' % ('%.3f' % agent.v[i * agent.env.ncol + j]),
                  end=' ')
        print()

    print("Policy:")
    for i in range(agent.env.nrow):
        for j in range(agent.env.ncol):
            # Special states, e.g. the cliff cells in cliff walking
            if (i * agent.env.ncol + j) in disaster:
                print('****', end=' ')
            elif (i * agent.env.ncol + j) in end:  # goal state
                print('EEEE', end=' ')
            else:
                a = agent.pi[i * agent.env.ncol + j]
                pi_str = ''
                for k in range(len(action_meaning)):
                    pi_str += action_meaning[k] if a[k] > 0 else 'o'
                print(pi_str, end=' ')
        print()


if __name__ == '__main__':
    random.seed(0)
    env = CliffWalkingEnv()
    action_meaning = ['^', 'v', '<', '>']
    theta = 0.001
    gamma = 0.9

    agent = PolicyIteration(env, theta, gamma)
    start = time.time()
    agent.policy_iteration()
    end = time.time()
    print_agent(agent, action_meaning, disaster=list(range(37, 47)), end=[47])
    print(f"Policy iteration took {end - start}s")

    agent = ValueIteration(env, theta, gamma)
    start = time.time()
    agent.value_iteration()
    end = time.time()
    print_agent(agent, action_meaning, disaster=list(range(37, 47)), end=[47])
    print(f"Value iteration took {end - start}s")

